{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BGE Auto Embedder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "FlagEmbedding provides a high-level class, `FlagAutoModel`, that unifies inference across embedding models. Besides the BGE series, it also supports other popular open-source embedding models such as E5, GTE, and SFR. This tutorial shows how to use it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install FlagEmbedding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Usage"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, import `FlagAutoModel` from FlagEmbedding, and use the `from_finetuned()` function to initialize the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from FlagEmbedding import FlagAutoModel\n",
    "\n",
    "model = FlagAutoModel.from_finetuned(\n",
    "    'BAAI/bge-base-en-v1.5',\n",
    "    query_instruction_for_retrieval=\"Represent this sentence for searching relevant passages: \",\n",
    "    devices=\"cuda:0\",   # if not specified, will use all available gpus or cpu when no gpu available\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then use the model exactly as you would use `FlagModel` (or `BGEM3FlagModel` for BGE M3, `FlagLLMModel` for BGE Multilingual Gemma2, and `FlagICLModel` for BGE ICL):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.76   0.6714]\n",
      " [0.6177 0.7603]]\n"
     ]
    }
   ],
   "source": [
    "queries = [\"query 1\", \"query 2\"]\n",
    "corpus = [\"passage 1\", \"passage 2\"]\n",
    "\n",
    "# encode the queries and corpus\n",
    "q_embeddings = model.encode_queries(queries)\n",
    "p_embeddings = model.encode_corpus(corpus)\n",
    "\n",
    "# compute the similarity scores\n",
    "scores = q_embeddings @ p_embeddings.T\n",
    "print(scores)"
   ]
  },
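  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each row of `scores` corresponds to a query and each column to a passage, so the best match for a query is the column with the highest value in its row. The following is a minimal sketch in plain Python that reuses the scores printed above; the helper name `best_passage` is hypothetical and not part of FlagEmbedding.\n",
    "\n",
    "```python\n",
    "# Scores copied from the output above: rows are queries, columns are passages.\n",
    "scores = [\n",
    "    [0.76, 0.6714],\n",
    "    [0.6177, 0.7603],\n",
    "]\n",
    "corpus = ['passage 1', 'passage 2']\n",
    "\n",
    "\n",
    "def best_passage(score_row, passages):\n",
    "    # Hypothetical helper: return the passage with the highest score in one row.\n",
    "    best_index = max(range(len(score_row)), key=lambda j: score_row[j])\n",
    "    return passages[best_index]\n",
    "\n",
    "\n",
    "for i, row in enumerate(scores):\n",
    "    print(f'query {i + 1} -> {best_passage(row, corpus)}')\n",
    "```\n",
    "\n",
    "With these scores, each query is matched to the passage with the same index."
   ]
  },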
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Explanation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`FlagAutoModel` uses an OrderedDict, `AUTO_EMBEDDER_MAPPING`, to store the configurations of all supported models:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['bge-en-icl',\n",
       " 'bge-multilingual-gemma2',\n",
       " 'bge-m3',\n",
       " 'bge-large-en-v1.5',\n",
       " 'bge-base-en-v1.5',\n",
       " 'bge-small-en-v1.5',\n",
       " 'bge-large-zh-v1.5',\n",
       " 'bge-base-zh-v1.5',\n",
       " 'bge-small-zh-v1.5',\n",
       " 'bge-large-en',\n",
       " 'bge-base-en',\n",
       " 'bge-small-en',\n",
       " 'bge-large-zh',\n",
       " 'bge-base-zh',\n",
       " 'bge-small-zh',\n",
       " 'e5-mistral-7b-instruct',\n",
       " 'e5-large-v2',\n",
       " 'e5-base-v2',\n",
       " 'e5-small-v2',\n",
       " 'multilingual-e5-large-instruct',\n",
       " 'multilingual-e5-large',\n",
       " 'multilingual-e5-base',\n",
       " 'multilingual-e5-small',\n",
       " 'e5-large',\n",
       " 'e5-base',\n",
       " 'e5-small',\n",
       " 'gte-Qwen2-7B-instruct',\n",
       " 'gte-Qwen2-1.5B-instruct',\n",
       " 'gte-Qwen1.5-7B-instruct',\n",
       " 'gte-multilingual-base',\n",
       " 'gte-large-en-v1.5',\n",
       " 'gte-base-en-v1.5',\n",
       " 'gte-large',\n",
       " 'gte-base',\n",
       " 'gte-small',\n",
       " 'gte-large-zh',\n",
       " 'gte-base-zh',\n",
       " 'gte-small-zh',\n",
       " 'SFR-Embedding-2_R',\n",
       " 'SFR-Embedding-Mistral',\n",
       " 'Linq-Embed-Mistral']"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from FlagEmbedding.inference.embedder.model_mapping import AUTO_EMBEDDER_MAPPING\n",
    "\n",
    "list(AUTO_EMBEDDER_MAPPING.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EmbedderConfig(model_class=<class 'FlagEmbedding.inference.embedder.decoder_only.icl.ICLLLMEmbedder'>, pooling_method=<PoolingMethod.LAST_TOKEN: 'last_token'>, trust_remote_code=False, query_instruction_format='<instruct>{}\\n<query>{}')\n"
     ]
    }
   ],
   "source": [
    "print(AUTO_EMBEDDER_MAPPING['bge-en-icl'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The value of each key is an `EmbedderConfig` object, which has four attributes:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```python\n",
    "@dataclass\n",
    "class EmbedderConfig:\n",
    "    model_class: Type[AbsEmbedder]\n",
    "    pooling_method: PoolingMethod\n",
    "    trust_remote_code: bool = False\n",
    "    query_instruction_format: str = \"{}{}\"\n",
    "```"
   ]
  },
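  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the role of `query_instruction_format` concrete, the sketch below fills a template with Python's `str.format`, using the default template and the `bge-en-icl` template printed above. It assumes the first placeholder takes the instruction and the second takes the query, which is what the `<instruct>{}\\n<query>{}` format suggests. `format_query` is a hypothetical helper, not a FlagEmbedding function, and the second instruction string is made up for illustration.\n",
    "\n",
    "```python\n",
    "def format_query(template, instruction, query):\n",
    "    # Hypothetical helper: fill the instruction, then the query, into the template.\n",
    "    return template.format(instruction, query)\n",
    "\n",
    "\n",
    "default_template = '{}{}'\n",
    "icl_template = '<instruct>{}\\n<query>{}'\n",
    "\n",
    "instruction = 'Represent this sentence for searching relevant passages: '\n",
    "print(format_query(default_template, instruction, 'query 1'))\n",
    "print(format_query(icl_template, 'Retrieve relevant passages.', 'query 1'))\n",
    "```\n",
    "\n",
    "Under this assumption, the default template simply prepends the instruction to the query, while the ICL template wraps each part in its own tag."
   ]
  },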
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Besides the BGE series, other models such as E5 are supported in the same way:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(AUTO_EMBEDDER_MAPPING['e5-mistral-7b-instruct'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Customization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you want to use your own models through `FlagAutoModel`, work through the following questions:\n",
    "\n",
    "1. What type of embedding model is it, an encoder or a decoder? Choose the model class accordingly.\n",
    "2. Which pooling method does it use: CLS token, mean pooling, or last token?\n",
    "3. Does your model need `trust_remote_code=True` to run?\n",
    "4. Is there a query instruction format for retrieval?\n",
    "\n",
    "Once these four attributes are settled, add your model name as the key and the corresponding `EmbedderConfig` as the value to `AUTO_EMBEDDER_MAPPING`. Now give it a try!"
   ]
  }
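  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The four questions above map onto the four fields of `EmbedderConfig`. The sketch below mirrors that dataclass with local stand-ins so that it runs without FlagEmbedding installed: `PoolingMethod`, `MyEncoderEmbedder`, `local_mapping`, and the key `my-embedder-v1` are all placeholders. Only the `LAST_TOKEN` member is taken from the output shown earlier; `CLS` and `MEAN` are assumed names. In real code you would use the library's own classes and add the entry to `AUTO_EMBEDDER_MAPPING` instead.\n",
    "\n",
    "```python\n",
    "from collections import OrderedDict\n",
    "from dataclasses import dataclass\n",
    "from enum import Enum\n",
    "\n",
    "\n",
    "class PoolingMethod(Enum):\n",
    "    # Stand-in enum; only LAST_TOKEN appears in the output shown earlier.\n",
    "    CLS = 'cls'\n",
    "    MEAN = 'mean'\n",
    "    LAST_TOKEN = 'last_token'\n",
    "\n",
    "\n",
    "class MyEncoderEmbedder:\n",
    "    # Placeholder for an embedder class (question 1: encoder or decoder).\n",
    "    pass\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class EmbedderConfig:\n",
    "    # Local mirror of the dataclass shown in section 2.\n",
    "    model_class: type\n",
    "    pooling_method: PoolingMethod\n",
    "    trust_remote_code: bool = False\n",
    "    query_instruction_format: str = '{}{}'\n",
    "\n",
    "\n",
    "# Stand-in for the library's mapping of model names to configurations.\n",
    "local_mapping = OrderedDict()\n",
    "local_mapping['my-embedder-v1'] = EmbedderConfig(\n",
    "    model_class=MyEncoderEmbedder,       # question 1\n",
    "    pooling_method=PoolingMethod.CLS,    # question 2\n",
    "    trust_remote_code=False,             # question 3\n",
    "    query_instruction_format='{}{}',     # question 4\n",
    ")\n",
    "\n",
    "print(local_mapping['my-embedder-v1'])\n",
    "```"
   ]
  }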
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
